{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a400d67a",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aee6c148",
   "metadata": {},
   "source": [
    "# Autodiff and Backpropagation\n",
    "\n",
    "## Jacobian\n",
    "\n",
    "Let ${\\bf f}:\\mathbb{R}^n\\to \\mathbb{R}^m$, we define its Jacobian as:\n",
    "\\begin{align*}\n",
    "\\newcommand{\\bbx}{{\\bf x}}\n",
    "\\newcommand{\\bbv}{{\\bf v}}\n",
    "\\newcommand{\\bbw}{{\\bf w}}\n",
    "\\newcommand{\\bbu}{{\\bf u}}\n",
    "\\newcommand{\\bbf}{{\\bf f}}\n",
    "\\newcommand{\\bbg}{{\\bf g}}\n",
    "\\frac{\\partial \\bbf}{\\partial \\bbx} = J_{\\bbf}(\\bbx) &= \\left( \\begin{array}{ccc}\n",
    "\\frac{\\partial f_1}{\\partial x_1}&\\dots& \\frac{\\partial f_1}{\\partial x_n}\\\\\n",
    "\\vdots&&\\vdots\\\\\n",
    "\\frac{\\partial f_m}{\\partial x_1}&\\dots& \\frac{\\partial f_m}{\\partial x_n}\n",
    "\\end{array}\\right)\\\\\n",
    "&=\\left( \\frac{\\partial \\bbf}{\\partial x_1},\\dots, \\frac{\\partial \\bbf}{\\partial x_n}\\right)\\\\\n",
    "&=\\left(\n",
    "\\begin{array}{c}\n",
    "\\nabla f_1(\\bbx)^T\\\\\n",
    "\\vdots\\\\\n",
    "\\nabla f_m(x)^T\n",
    "\\end{array}\\right)\n",
    "\\end{align*}\n",
    "\n",
    "Hence the Jacobian $J_{\\bbf}(\\bbx)\\in \\mathbb{R}^{m\\times n}$ is a linear map from $\\mathbb{R}^n$ to $\\mathbb{R}^m$ such that for $\\bbx,\\bbv \\in \\mathbb{R}^n$ and $h\\in \\mathbb{R}$:\n",
    "\\begin{align*}\n",
    "\\bbf(\\bbx+h\\bbv) = \\bbf(\\bbx) + hJ_{\\bbf}(\\bbx)\\bbv +o(h).\n",
    "\\end{align*}\n",
    "The term $J_{\\bbf}(\\bbx)\\bbv\\in \\mathbb{R}^m$ is a Jacobian Vector Product (**JVP**), correponding to the interpretation where the Jacobian is the linear map: $J_{\\bbf}(\\bbx):\\mathbb{R}^n \\to \\mathbb{R}^m$, where $J_{\\bbf}(\\bbx)(\\bbv)=J_{\\bbf}(\\bbx)\\bbv$.\n",
    "\n",
    "## Chain composition\n",
    "\n",
    "In machine learning, we are computing gradient of the loss function with respect to the parameters. In particular, if the parameters are high-dimensional, the loss is a real number. Hence, consider a real-valued function $\\bbf:\\mathbb{R}^n\\stackrel{\\bbg_1}{\\to}\\mathbb{R}^m \\stackrel{\\bbg_2}{\\to}\\mathbb{R}^d\\stackrel{h}{\\to}\\mathbb{R}$, so that $\\bbf(\\bbx) = h(\\bbg_2(\\bbg_1(\\bbx)))\\in \\mathbb{R}$. We have\n",
    "\\begin{align*}\n",
    "\\underbrace{\\nabla\\bbf(\\bbx)}_{n\\times 1}=\\underbrace{J_{\\bbg_1}(\\bbx)^T}_{n\\times m}\\underbrace{J_{\\bbg_2}(\\bbg_1(\\bbx))^T}_{m\\times d}\\underbrace{\\nabla h(\\bbg_2(\\bbg_1(\\bbx)))}_{d\\times 1}.\n",
    "\\end{align*}\n",
    "To do this computation, if we start from the right so that we start with a matrix times a vector to obtain a vector (of size $m$) and we need to make another matrix times a vector, resulting in $O(nm+md)$ operations. If we start from the left with the matrix-matrix multiplication, we get $O(nmd+nd)$ operations. Hence we see that as soon as $m\\approx d$, starting for the right is much more efficient. Note however that doing the computation from the right to the left requires to keep in memory the values of $\\bbg_1(\\bbx)\\in\\mathbb{R}^m$, and $\\bbx\\in \\mathbb{R}^n$.\n",
    "\n",
    "**Backpropagation** is an efficient algorithm computing the gradient \"from the right to the left\", i.e. backward. In particular, we will need to compute quantities of the form: $J_{\\bbf}(\\bbx)^T\\bbu \\in \\mathbb{R}^n$ with $\\bbu \\in\\mathbb{R}^m$ which can be rewritten $\\bbu^T J_{\\bbf}(\\bbx)$ which is a Vector Jacobian Product (**VJP**), correponding to the interpretation where the Jacobian is the linear map: $J_{\\bbf}(\\bbx):\\mathbb{R}^n \\to \\mathbb{R}^m$, composed with the linear map $\\bbu:\\mathbb{R}^m\\to \\mathbb{R}$ so that $\\bbu^TJ_{\\bbf}(\\bbx) = \\bbu \\circ J_{\\bbf}(\\bbx)$.\n",
    "\n",
    "**example:** let $\\bbf(\\bbx, W) = \\bbx W\\in \\mathbb{R}^b$ where $W\\in \\mathbb{R}^{a\\times b}$ and $\\bbx\\in \\mathbb{R}^a$. We clearly have\n",
    "$$\n",
    "J_{\\bbf}(\\bbx) = W^T.\n",
    "$$\n",
    "Note that here, we are slightly abusing notations and considering the partial function $\\bbx\\mapsto \\bbf(\\bbx, W)$. To see this, we can write $f_j = \\sum_{i}x_iW_{ij}$ so that \n",
    "$$\n",
    "\\frac{\\partial \\bbf}{\\partial x_i}= \\left( W_{i1}\\dots W_{ib}\\right)^T\n",
    "$$\n",
    "Then recall from definitions that\n",
    "$$\n",
    "J_{\\bbf}(\\bbx) = \\left( \\frac{\\partial \\bbf}{\\partial x_1},\\dots, \\frac{\\partial \\bbf}{\\partial x_n}\\right)=W^T.\n",
    "$$\n",
    "Now we clearly have\n",
    "$$\n",
    "J_{\\bbf}(W) = \\bbx \\text{ since, } \\bbf(\\bbx,W+\\Delta W) = \\bbf(\\bbx,W) + \\bbx \\Delta W.\n",
    "$$\n",
    "Note that multiplying $\\bbx$ on the right is actually convenient when using broadcasting, i.e. we can take a batch of input vectors of shape $\\text{bs}\\times a$ without modifying the math above. \n",
    "\n",
    "## Implementation\n",
    "\n",
    "In PyTorch, `torch.autograd` provides classes and functions implementing automatic differentiation of arbitrary scalar valued functions. To create a custom [autograd.Function](https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), subclass this class and implement the `forward()` and `backward()` static methods. Here is an example:\n",
    "```python=\n",
    "class Exp(Function):\n",
    "    @staticmethod\n",
    "    def forward(ctx, i):\n",
    "        result = i.exp()\n",
    "        ctx.save_for_backward(result)\n",
    "        return result\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        result, = ctx.saved_tensors\n",
    "        return grad_output * result\n",
    "# Use it by calling the apply method:\n",
    "output = Exp.apply(input)\n",
    "```\n",
    "You can have a look at [Module 2b](https://dataflowr.github.io/website/modules/2b-automatic-differentiation) to learn more about this approach as well as [MLP from scratch](https://dataflowr.github.io/website/homework/1-mlp-from-scratch/).\n",
    "\n",
    "### Backprop the functional way\n",
    "\n",
    "Here we will implement in `numpy` a different approach mimicking the functional approach of [JAX](https://jax.readthedocs.io/en/latest/index.html) see [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#).\n",
    "\n",
    "Each function will take 2 arguments: one being the input `x` and the other being the parameters `w`. For each function, we build 2 **vjp** functions taking as argument a gradient $\\bbu$ and corresponding to $J_{\\bbf}(\\bbx)$ and $J_{\\bbf}(\\bbw)$ so that these functions return $J_{\\bbf}(\\bbx)^T \\bbu$ and $J_{\\bbf}(\\bbw)^T \\bbu$ respectively. To summarize, for $\\bbx \\in \\mathbb{R}^n$, $\\bbw \\in \\mathbb{R}^d$, and, $\\bbf(\\bbx,\\bbw) \\in \\mathbb{R}^m$,\n",
    "\\begin{align*}\n",
    "{\\bf jvp}_\\bbx(\\bbu) &= J_{\\bbf}(\\bbx)^T \\bbu, \\text{ with } J_{\\bbf}(\\bbx)\\in\\mathbb{R}^{m\\times n}, \\bbu\\in \\mathbb{R}^m\\\\\n",
    "{\\bf jvp}_\\bbw(\\bbu) &= J_{\\bbf}(\\bbw)^T \\bbu, \\text{ with } J_{\\bbf}(\\bbw)\\in\\mathbb{R}^{m\\times d}, \\bbu\\in \\mathbb{R}^m\n",
    "\\end{align*}\n",
    "Then backpropagation is simply done by first computing the gradient of the loss and then composing the **vjp** functions in the right order."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e63c3567",
   "metadata": {},
   "source": [
    "### Example: adding bias\n",
    "\n",
    "We start with the simple example of adding a bias."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2159659c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14316ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add(x, b):\n",
    "    return x + b\n",
    "\n",
    "def add_make_vjp(x, b):\n",
    "    def vjp(u):\n",
    "        return u, u\n",
    "    return vjp\n",
    "\n",
    "add.make_vjp = add_make_vjp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33819c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(0)\n",
    "x = rng.random((30,2)).astype('float32')\n",
    "b_source  = np.array([1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25fbe4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "xb = add(x,b_source)\n",
    "np.allclose(xb, x+b_source)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3df81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "vjp_add = add.make_vjp(x,b_source)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e7d84da",
   "metadata": {},
   "outputs": [],
   "source": [
    "grad_x, grad_b = vjp_add(rng.random((30,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963e2725",
   "metadata": {},
   "outputs": [],
   "source": [
    "grad_x.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f88fd9f6",
   "metadata": {},
   "source": [
    "### Exercise: dot product and squared loss\n",
    "\n",
    "Implement the corresponding vjp functions. (Note: we are abusing notation for the squared loss as the target `y` is not a parameter and should not be updated! Moreover, the vjp_{y_pred} function for the squared loss does not depend on its input `u`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e64a231",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dot(x, W):\n",
    "    return np.dot(x, W)\n",
    "\n",
    "def dot_make_vjp(x, W):\n",
    "    def vjp(u):\n",
    "        #your code\n",
    "    return vjp\n",
    "\n",
    "dot.make_vjp = dot_make_vjp\n",
    "\n",
    "def squared_loss(y_pred, y):\n",
    "    return np.array([np.sum((y - y_pred) ** 2)])\n",
    "\n",
    "def squared_loss_make_vjp(y_pred, y):\n",
    "    def vjp(u):\n",
    "        # your code\n",
    "    return vjp\n",
    "\n",
    "squared_loss.make_vjp = squared_loss_make_vjp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d3b544d",
   "metadata": {},
   "source": [
    "# Setup\n",
    "\n",
    "As in [02b_linear_reg.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/02b_linear_reg.ipynb), our model is:\n",
    "$$\n",
    "y_t = 2x^1_t-3x^2_t+1, \\quad t\\in\\{1,\\dots,30\\}\n",
    "$$\n",
    "\n",
    "Our task is given the 'observations' $(x_t,y_t)_{t\\in\\{1,\\dots,30\\}}$ to recover the weights $w^1=2, w^2=-3$ and the bias $b = 1$.\n",
    "\n",
    "In order to do so, we will solve the following optimization problem:\n",
    "$$\n",
    "\\underset{w^1,w^2,b}{\\operatorname{argmin}} \\sum_{t=1}^{30} \\left(w^1x^1_t+w^2x^2_t+b-y_t\\right)^2\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1bddb56",
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.RandomState(0)\n",
    "# generate random input data\n",
    "x = rng.random((30,2)).astype('float32')\n",
    "# generate labels corresponding to input data x\n",
    "y = np.dot(x, [2., -3.]) + 1.\n",
    "y = np.expand_dims(y, axis=1).astype('float32')\n",
    "w_source = np.array([2., -3.])\n",
    "b_source  = np.array([1.])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93d172f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_feed_forward(y, seed=0):\n",
    "    rng = np.random.RandomState(seed)\n",
    "    funcs = [dot,add,squared_loss]\n",
    "    params = [rng.randn(2,1),rng.randn(1),y]\n",
    "    return funcs, params"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6bc0e93",
   "metadata": {},
   "source": [
    "### Forward pass\n",
    "\n",
    "The following function should take a batch of inputs, functions and their parameters and return the final value or all values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7b1d955",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_chain(x, funcs, params, return_all=False):\n",
    "    all_x = [x]\n",
    "    #\n",
    "    # your code\n",
    "    #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de5a22c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "funcs, params = create_feed_forward(y=y, seed=0)\n",
    "W, b, _ = params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c76a3272",
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = evaluate_chain(x, funcs, params, return_all=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9b2ae75",
   "metadata": {},
   "source": [
    "### Backward pass\n",
    "\n",
    "The following function should do the forward pass and then the backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b456409c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def backward_diff_chain(x, funcs, params):\n",
    "    \"\"\"\n",
    "    Reverse-mode differentiation of a chain of computations.\n",
    "\n",
    "    Args:\n",
    "    x: initial input to the chain.\n",
    "    funcs: a list of functions of the form func(x, param).\n",
    "    params: a list of parameters, with len(params) = len(funcs).\n",
    "    Returns:\n",
    "    value, vjp_x, all vjp_params\n",
    "    \"\"\"\n",
    "    # Evaluate the feedforward model and store intermediate computations,\n",
    "    # as they will be needed during the backward pass.\n",
    "    xs = evaluate_chain(x, funcs, params, return_all=True)\n",
    "    K = len(funcs)  # Number of functions.\n",
    "    u = None # the gradient of the loss does not require an input\n",
    "    # List that will contain the Jacobian of each function w.r.t. parameters.\n",
    "    J = [None] * K\n",
    "\n",
    "    #\n",
    "    # your code\n",
    "    #"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6017ecbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss, grad_x, grads = backward_diff_chain(x, funcs, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b780680e",
   "metadata": {},
   "source": [
    "### Optimizer\n",
    "\n",
    "First compute the update for each parameter and then modify the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbb8e2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def optim_SGD(grads, learning_rate = 1e-2):\n",
    "    #your code\n",
    "    \n",
    "def update_params(updates, params):\n",
    "    # your code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d34d6f46",
   "metadata": {},
   "source": [
    "### Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d733022",
   "metadata": {},
   "outputs": [],
   "source": [
    "funcs, params = create_feed_forward(y=y, seed=0)\n",
    "W, b, _ = params\n",
    "for epoch in range(10):\n",
    "    loss, grad_x, grads = backward_diff_chain(x, funcs, params)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss)\n",
    "    updates = optim_SGD(grads)\n",
    "    params = update_params(updates, params)\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "print(params[:-1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f5f6bb5",
   "metadata": {},
   "source": [
    "### Jax implementation\n",
    "\n",
    "see [linear_regression_jax.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/linear_regression_jax.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335ee8c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import haiku as hk\n",
    "import optax\n",
    "from functools import partial\n",
    "\n",
    "class config:\n",
    "    size_out = 1\n",
    "    w_source = jnp.array(W)\n",
    "    b_source = jnp.array(b)\n",
    "    \n",
    "def _linear(x, config):\n",
    "    return hk.Linear(config.size_out,w_init=hk.initializers.Constant(config.w_source), b_init=hk.initializers.Constant(config.b_source))(x)\n",
    "\n",
    "def mse_loss(y_pred, y_t):\n",
    "    return jax.lax.integer_pow(y_pred - y_t,2).sum()\n",
    "\n",
    "def loss_fn(x_in, y_t, config):\n",
    "    return mse_loss(_linear(x=x_in, config=config),y_t)\n",
    "\n",
    "hk_loss_fn = hk.without_apply_rng(hk.transform(partial(loss_fn, config=config)))\n",
    "params = hk_loss_fn.init(x_in=x,y_t=y,rng=None)\n",
    "loss_fn = hk_loss_fn.apply\n",
    "\n",
    "optimizer = optax.sgd(learning_rate=1e-2)\n",
    "\n",
    "opt_state = optimizer.init(params)\n",
    "for epoch in range(10):\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params,x_in=x,y_t=y)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss)\n",
    "    updates, opt_state = optimizer.update(grads, opt_state, params)\n",
    "    params = optax.apply_updates(params, updates)\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "print(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0149d7e0",
   "metadata": {},
   "source": [
    "### PyTorch implementation\n",
    "\n",
    "See [02b_linear_reg.ipynb](https://github.com/dataflowr/notebooks/blob/master/Module2/02b_linear_reg.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e71f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "dtype = torch.FloatTensor\n",
    "w_init_t = torch.from_numpy(W).type(dtype)\n",
    "b_init_t = torch.from_numpy(b).type(dtype)\n",
    "x_t = torch.from_numpy(x).type(dtype)\n",
    "y_t = torch.from_numpy(y).type(dtype)\n",
    "\n",
    "model = torch.nn.Sequential(torch.nn.Linear(2, 1),)\n",
    "\n",
    "for m in model.children():\n",
    "    m.weight.data = w_init_t.T.clone()\n",
    "    m.bias.data = b_init_t.clone()\n",
    "\n",
    "loss_fn = torch.nn.MSELoss(reduction='sum')\n",
    "\n",
    "model.train()\n",
    "\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)\n",
    "\n",
    "for epoch in range(10):\n",
    "    y_pred = model(x_t)\n",
    "    loss = loss_fn(y_pred, y_t)\n",
    "    print(\"progress:\", \"epoch:\", epoch, \"loss\",loss.item())\n",
    "    # Zero gradients, perform a backward pass, and update the weights.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "# After training\n",
    "print(\"estimation of the parameters:\")\n",
    "for param in model.parameters():\n",
    "    print(param)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jax",
   "language": "python",
   "name": "jax"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
